{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:02:17.301900Z",
     "start_time": "2017-09-24T07:02:16.445875Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "\n",
    "import re\n",
    "import collections\n",
    "import itertools\n",
    "import bcolz\n",
    "import pickle\n",
    "sys.path.append('../../lib')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import gc\n",
    "import random\n",
    "import smart_open\n",
    "import h5py\n",
    "import csv\n",
    "import json\n",
    "import functools\n",
    "import time\n",
    "import string\n",
    "\n",
    "import datetime as dt\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "random_state_number = 967898"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:02:18.146019Z",
     "start_time": "2017-09-24T07:02:17.303484Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['/gpu:0', '/gpu:1']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.python.client import device_lib\n",
    "def get_available_gpus():\n",
    "    local_device_protos = device_lib.list_local_devices()\n",
    "    return [x.name for x in local_device_protos if x.device_type == 'GPU']\n",
    "\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth=True\n",
    "sess = tf.Session(config=config)\n",
    "get_available_gpus()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:02:18.260494Z",
     "start_time": "2017-09-24T07:02:18.147847Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using matplotlib backend: TkAgg\n",
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bicepjai/Programs/anaconda3/envs/dsotc-c3/lib/python3.6/site-packages/IPython/core/magics/pylab.py:160: UserWarning: pylab import has clobbered these variables: ['random']\n",
      "`%matplotlib` prevents importing * from pylab and numpy\n",
      "  \"\\n`%matplotlib` prevents importing * from pylab and numpy\"\n"
     ]
    }
   ],
   "source": [
    "%pylab\n",
    "%matplotlib inline\n",
    "%load_ext line_profiler\n",
    "%load_ext memory_profiler\n",
    "%load_ext autoreload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:02:18.265555Z",
     "start_time": "2017-09-24T07:02:18.262104Z"
    }
   },
   "outputs": [],
   "source": [
    "pd.options.mode.chained_assignment = None\n",
    "pd.options.display.max_columns = 999\n",
    "color = sns.color_palette()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:02:37.886019Z",
     "start_time": "2017-09-24T07:02:19.333371Z"
    }
   },
   "outputs": [],
   "source": [
    "store = pd.HDFStore('../data_prep/processed/stage1/data_frames.h5')\n",
    "train_df = store['train_df']\n",
    "test_df = store['test_df']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:02:38.250337Z",
     "start_time": "2017-09-24T07:02:37.887282Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ID</th>\n",
       "      <th>Gene</th>\n",
       "      <th>Variation</th>\n",
       "      <th>Class</th>\n",
       "      <th>Sentences</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>[fam58a]</td>\n",
       "      <td>[truncating, mutations]</td>\n",
       "      <td>1</td>\n",
       "      <td>[[cyclin-dependent, kinases, , cdks, , regulat...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>[cbl]</td>\n",
       "      <td>[w802*]</td>\n",
       "      <td>2</td>\n",
       "      <td>[[abstract, background, non-small, cell, lung,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>[cbl]</td>\n",
       "      <td>[q249e]</td>\n",
       "      <td>2</td>\n",
       "      <td>[[abstract, background, non-small, cell, lung,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>[cbl]</td>\n",
       "      <td>[n454d]</td>\n",
       "      <td>3</td>\n",
       "      <td>[[recent, evidence, has, demonstrated, that, a...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>[cbl]</td>\n",
       "      <td>[l399v]</td>\n",
       "      <td>4</td>\n",
       "      <td>[[oncogenic, mutations, in, the, monomeric, ca...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   ID      Gene                Variation  Class  \\\n",
       "0   0  [fam58a]  [truncating, mutations]      1   \n",
       "1   1     [cbl]                  [w802*]      2   \n",
       "2   2     [cbl]                  [q249e]      2   \n",
       "3   3     [cbl]                  [n454d]      3   \n",
       "4   4     [cbl]                  [l399v]      4   \n",
       "\n",
       "                                           Sentences  \n",
       "0  [[cyclin-dependent, kinases, , cdks, , regulat...  \n",
       "1  [[abstract, background, non-small, cell, lung,...  \n",
       "2  [[abstract, background, non-small, cell, lung,...  \n",
       "3  [[recent, evidence, has, demonstrated, that, a...  \n",
       "4  [[oncogenic, mutations, in, the, monomeric, ca...  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ID</th>\n",
       "      <th>Gene</th>\n",
       "      <th>Variation</th>\n",
       "      <th>Sentences</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>[acsl4]</td>\n",
       "      <td>[r570s]</td>\n",
       "      <td>[[2, this, mutation, resulted, in, a, myelopro...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>[naglu]</td>\n",
       "      <td>[p521l]</td>\n",
       "      <td>[[abstract, the, large, tumor, suppressor, 1, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>[pah]</td>\n",
       "      <td>[l333f]</td>\n",
       "      <td>[[vascular, endothelial, growth, factor, recep...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>[ing1]</td>\n",
       "      <td>[a148d]</td>\n",
       "      <td>[[inflammatory, myofibroblastic, tumor, , imt,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>[tmem216]</td>\n",
       "      <td>[g77a]</td>\n",
       "      <td>[[abstract, retinoblastoma, is, a, pediatric, ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   ID       Gene Variation                                          Sentences\n",
       "0   0    [acsl4]   [r570s]  [[2, this, mutation, resulted, in, a, myelopro...\n",
       "1   1    [naglu]   [p521l]  [[abstract, the, large, tumor, suppressor, 1, ...\n",
       "2   2      [pah]   [l333f]  [[vascular, endothelial, growth, factor, recep...\n",
       "3   3     [ing1]   [a148d]  [[inflammatory, myofibroblastic, tumor, , imt,...\n",
       "4   4  [tmem216]    [g77a]  [[abstract, retinoblastoma, is, a, pediatric, ..."
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display(train_df.head())\n",
    "display(test_df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:02:38.358654Z",
     "start_time": "2017-09-24T07:02:38.251731Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(352220, 352220)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with open('../data_prep/processed/stage1/vocab_words_wordidx.pkl', 'rb') as f:\n",
    "    (vocab_words, vocab_wordidx) = pickle.load(f)\n",
    "vocab_size = len(vocab_words)\n",
    "vocab_size, len(vocab_wordidx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train-Test Split and Data Prep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:02:38.402860Z",
     "start_time": "2017-09-24T07:02:38.359787Z"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split, StratifiedShuffleSplit\n",
    "from sklearn.metrics import make_scorer, f1_score, precision_score, accuracy_score, log_loss\n",
    "f1_scorer = make_scorer(f1_score, average=\"macro\")\n",
    "precision_scorer = make_scorer(precision_score, average=\"macro\")\n",
    "accuracy_scorer = make_scorer(accuracy_score, average=\"macro\")\n",
    "log_loss_scorer = make_scorer(log_loss)\n",
    "\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.model_selection import cross_val_score, GridSearchCV, RandomizedSearchCV\n",
    "\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.feature_extraction.text import TfidfTransformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:02:40.387283Z",
     "start_time": "2017-09-24T07:02:38.404097Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Text</th>\n",
       "      <th>Class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[fam58a, truncating, mutations, cyclin-depende...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[cbl, w802*, abstract, background, non-small, ...</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[cbl, q249e, abstract, background, non-small, ...</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[cbl, n454d, recent, evidence, has, demonstrat...</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[cbl, l399v, oncogenic, mutations, in, the, mo...</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                Text  Class\n",
       "0  [fam58a, truncating, mutations, cyclin-depende...      1\n",
       "1  [cbl, w802*, abstract, background, non-small, ...      2\n",
       "2  [cbl, q249e, abstract, background, non-small, ...      2\n",
       "3  [cbl, n454d, recent, evidence, has, demonstrat...      3\n",
       "4  [cbl, l399v, oncogenic, mutations, in, the, mo...      4"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "train_df.Sentences = train_df.Sentences.apply(lambda ll: list(itertools.chain.from_iterable(ll)))\n",
    "all_text_train_df = pd.DataFrame()\n",
    "all_text_train_df[\"Text\"] = train_df.Gene + train_df.Variation + train_df.Sentences\n",
    "all_text_train_df[\"Class\"] = train_df.Class\n",
    "display(all_text_train_df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:02:40.394982Z",
     "start_time": "2017-09-24T07:02:40.388426Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2988,) (2988,)\n",
      "(333,) (333,)\n"
     ]
    }
   ],
   "source": [
    "x_train, x_test, y_train, y_test = train_test_split(all_text_train_df.Text,all_text_train_df.Class,\n",
    "                                                   test_size=0.10, random_state=random_state_number,\n",
    "                                                   stratify=all_text_train_df.Class)\n",
    "\n",
    "print(x_train.shape, y_train.shape)\n",
    "print(x_test.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:02:40.398306Z",
     "start_time": "2017-09-24T07:02:40.396188Z"
    }
   },
   "outputs": [],
   "source": [
    "del all_text_train_df\n",
    "del train_df\n",
    "#del test_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:02:40.401606Z",
     "start_time": "2017-09-24T07:02:40.399373Z"
    }
   },
   "outputs": [],
   "source": [
    "cvec = CountVectorizer(vocabulary=vocab_wordidx)\n",
    "tfidf = TfidfTransformer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:02:59.378427Z",
     "start_time": "2017-09-24T07:02:40.402695Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2988, 352220)\n",
      "(2988, 352220)\n"
     ]
    }
   ],
   "source": [
    "x_train = x_train.str.join(\" \")\n",
    "x_train_counts = cvec.fit_transform(x_train, y_train)\n",
    "print(x_train_counts.shape)\n",
    "x_train_tf = tfidf.fit_transform(x_train_counts)\n",
    "print(x_train_tf.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:03:01.650715Z",
     "start_time": "2017-09-24T07:02:59.379671Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(333, 352220)\n",
      "(333, 352220)\n"
     ]
    }
   ],
   "source": [
    "x_test = x_test.str.join(\" \")\n",
    "x_test_counts = cvec.fit_transform(x_test, y_test)\n",
    "print(x_test_counts.shape)\n",
    "x_test_tf = tfidf.fit_transform(x_test_counts)\n",
    "print(x_test_tf.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:03:01.655436Z",
     "start_time": "2017-09-24T07:03:01.652204Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "333"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T07:03:04.396372Z",
     "start_time": "2017-09-24T07:03:01.656714Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "197"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# XGBoost"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "not able to run xgboost, system dies in the middle on 3 tries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2017-09-24T07:02:26.769Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bicepjai/Programs/anaconda3/envs/dsotc-c3/lib/python3.6/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "from xgboost import XGBClassifier\n",
    "from xgboost.core import XGBoostError\n",
    "from sklearn.model_selection import RandomizedSearchCV, GridSearchCV\n",
    "import scipy.stats as st"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2017-09-24T07:02:30.038Z"
    }
   },
   "outputs": [],
   "source": [
    "xgb_clfr = XGBClassifier(objective=\"multi:softprob\",n_estimators=500,  learning_rate=0.01, nthread=10)\n",
    "xgb_clfr.fit(x_train_tf.toarray(),y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CatBoost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T06:41:56.274397Z",
     "start_time": "2017-09-24T06:41:56.222980Z"
    }
   },
   "outputs": [],
   "source": [
    "import catboost\n",
    "from catboost import CatBoostClassifier, CatboostError\n",
    "from sklearn.model_selection import RandomizedSearchCV, GridSearchCV\n",
    "from sklearn.model_selection import ParameterSampler\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "import scipy.stats as st"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T06:41:56.994999Z",
     "start_time": "2017-09-24T06:41:56.982465Z"
    }
   },
   "outputs": [],
   "source": [
    "cat_clfr = CatBoostClassifier(iterations=500, loss_function='MultiClass', \n",
    "                                thread_count=8, train_dir = \"temp_files\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T02:50:58.677060Z",
     "start_time": "2017-09-23T23:29:57.422923Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<catboost.core.CatBoostClassifier at 0x7fa2037f0cf8>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cat_clfr.fit(x_train_tf.toarray(), y_train, \n",
    "               use_best_model=True,\n",
    "               eval_set=(x_test_tf.toarray(), y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T04:16:42.407838Z",
     "start_time": "2017-09-24T04:16:42.297950Z"
    }
   },
   "outputs": [],
   "source": [
    "#cat_clfr.save_model(\"stage1/catboost_classifier\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T06:42:00.566386Z",
     "start_time": "2017-09-24T06:42:00.462955Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<catboost.core.CatBoostClassifier at 0x7f2564f8a0f0>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cat_clfr.load_model(\"stage1/catboost_classifier\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-24T06:43:08.180246Z",
     "start_time": "2017-09-24T06:42:03.164082Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.587453613072\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bicepjai/Programs/anaconda3/envs/dsotc-c3/lib/python3.6/site-packages/sklearn/metrics/classification.py:1135: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.\n",
      "  'precision', 'predicted', average, warn_for)\n"
     ]
    }
   ],
   "source": [
    "y_pred = cat_clfr.predict(x_train_tf.toarray())\n",
    "print(f1_score(y_train, y_pred, average=\"macro\"))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": false,
   "toc_position": {
    "height": "789px",
    "left": "0px",
    "right": "1379px",
    "top": "52px",
    "width": "212px"
   },
   "toc_section_display": "block",
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
